他の記事を見てみる
Jun 28 2018 / 23:02:47
プログラミング > 機械学習 > NN(ニューラルネットワーク) >
Keyword:

[機械学習] 複数のモデルの合成を平均をとってやってみた

執筆者ホームページ

𝖢𝗋𝖾𝖺𝗍𝗂𝗏𝖾 𝖦𝖯

Abstract

誤差逆伝播法が正しく動くようになったことによってネットワークを任意の入力に対して任意の出力が出るように訓練できるようになった。しかし、実際のニューラルネットワークは1つの入力と出力をマップするものではない。たくさんの出力に対してそれぞれ期待する入力を刷り込んでやらないと使い物にならない。

この記事ではこれをモデルの合成と呼ぶ。

Method

今回はモデルの合成のため、単純な方法でおそらく一番最初に思いつく方法であろう「平均を取る」という方法をとってみる。

例えば、XORを返すネットワークを作ろうと思った時に、まず誤差逆伝播によって以下のネットワークを生成する。

準備するネットワーク群
Net1: [0.1, 0.1] === NN ===> [0.1]
Net2: [0.1, 0.9] === NN ===> [0.9]
Net3: [0.9, 0.1] === NN ===> [0.9]
Net4: [0.9, 0.9] === NN ===> [0.1]

1,0にするとバイアスやら重みやらが爆発するので(シグモヰド函数のしぇい?)0.1, 0.9の微妙な値にとどめておく。

ネットワークのサイズは以下のようにした。

ネットワークのサイズ
2 Inputs → 8 Dence → 1 Output

諸パラメータ
学習率 0.1
バイアス学習率 0.1

Result

Net1:
0.1[]0/ 0.1[]0/ 
0.461225[0.0626076, 0.726138, ]-0.234287/ 0.454895[0.511846, 0.216285, ]-0.253726/ 0.459151[0.847629, 0.494977, ]-0.29802/ 0.430412[-0.0373044, 0.746381, ]-0.351076/ 0.436245[0.408468, 0.711115, ]-0.368372/ 0.42463[0.139749, 0.583252, ]-0.376096/ 0.399576[0.0489678, 0.277297, ]-0.43986/ 0.393289[0.111047, -0.199863, ]-0.424628/ 
0.0908007[0.280726, 0.142204, 0.521335, 0.501612, 0.00522412, -0.129593, 0.46911, -0.345765, ]-2.95204/ 

Net2:
0.1[]0/ 0.9[]0/ 
0.503215[0.146477, 0.274475, ]-0.248816/ 0.588823[0.0621211, 0.690219, ]-0.268308/ 0.534071[0.900686, 0.400186, ]-0.313739/ 0.584462[0.689173, 0.645649, ]-0.308883/ 0.32696[0.377266, -0.346023, ]-0.448273/ 0.379205[0.349677, -0.016169, ]-0.513341/ 0.489395[0.0438794, 0.492251, ]-0.489839/ 0.517847[0.891157, 0.53516, ]-0.49934/ 
0.819289[0.0348507, 0.88743, 0.499484, 0.248255, 0.924724, 0.774956, 0.322028, 0.344511, ]-0.372613/ 

Net3:
0.9[]0/ 0.1[]0/ 
0.531983[0.183554, 0.973735, ]-0.134466/ 0.536353[0.24005, 0.884855, ]-0.158863/ 0.607243[0.577, 0.858314, ]-0.169395/ 0.589007[0.602019, 0.182352, ]-0.200188/ 0.50653[0.263784, 0.112243, ]-0.22251/ 0.512342[0.218085, 0.874374, ]-0.234336/ 0.6097[0.653418, 0.843184, ]-0.226344/ 0.570578[0.52439, 0.633861, ]-0.251128/ 
0.87136[0.490456, 0.798993, 0.67661, 0.780432, 0.0949463, 0.2536, 0.37344, 0.279007, ]-0.211869/ 

Net4:
0.9[]0/ 0.9[]0/ 
0.536251[0.545349, 0.30968, ]-0.624268/ 0.333379[-0.115768, 0.248922, ]-0.81278/ 0.266111[-0.287083, 0.114366, ]-0.858998/ 0.254287[0.0547503, -0.294857, ]-0.859783/ 0.338534[0.461304, -0.279614, ]-0.833353/ 0.227082[-0.410675, 0.0800331, ]-0.927282/ 0.19593[-0.164999, -0.328325, ]-0.967938/ 0.493162[0.627668, 0.133931, ]-0.712795/ 
0.111399[-0.0478862, 0.0449693, -0.233222, -0.176751, 0.304124, -0.222535, -0.0309973, 0.121029, ]-2.06487/ 

平均を取ってできたXORモデル:
0[]0/ 0[]0/ 
0[0.173263, 0.199853, ]-0.378459/ 0[0.243409, 0.260228, ]-0.447822/ 0[0.289539, 0.122841, ]-0.490845/ 0[0.115915, 0.350059, ]-0.428634/ 0[0.0149099, 0.0515775, ]-0.542516/ 0[0.308048, 0.464876, ]-0.446924/ 0[0.204587, 0.125914, ]-0.49599/ 0[-0.0283544, 0.344654, ]-0.638818/ 
0[0.20507, 0.197326, 0.341477, 0.312974, 0.371548, 0.0730721, 0.221641, 0.527845, ]-1.3484/ 

それぞれの値を当てはめて順伝播してみる:
0[]0/ 0[]0/ 
0.423003[0.234497, 0.571007, ]-0.310459/ 0.407715[0.174562, 0.51007, ]-0.373419/ 0.398903[0.509558, 0.466961, ]-0.410038/ 0.39413[0.327159, 0.319881, ]-0.429983/ 0.38506[0.377705, 0.0494304, ]-0.468127/ 0.374546[0.074209, 0.380373, ]-0.512764/ 0.370285[0.145316, 0.321101, ]-0.530995/ 0.384149[0.538566, 0.275772, ]-0.471973/ 
0.373951[0.189537, 0.468399, 0.366052, 0.338387, 0.332255, 0.169107, 0.283395, 0.0996955, ]-1.40035/ 

1[]0/ 0[]0/ 
0.481018[0.234497, 0.571007, ]-0.310459/ 0.450449[0.174562, 0.51007, ]-0.373419/ 0.524859[0.509558, 0.466961, ]-0.410038/ 0.474317[0.327159, 0.319881, ]-0.429983/ 0.47741[0.377705, 0.0494304, ]-0.468127/ 0.392085[0.074209, 0.380373, ]-0.512764/ 0.404758[0.145316, 0.321101, ]-0.530995/ 0.516642[0.538566, 0.275772, ]-0.471973/ 
0.412311[0.189537, 0.468399, 0.366052, 0.338387, 0.332255, 0.169107, 0.283395, 0.0996955, ]-1.40035/ 

0[]0/ 1[]0/ 
0.564771[0.234497, 0.571007, ]-0.310459/ 0.53411[0.174562, 0.51007, ]-0.373419/ 0.514227[0.509558, 0.466961, ]-0.410038/ 0.472502[0.327159, 0.319881, ]-0.429983/ 0.396829[0.377705, 0.0494304, ]-0.468127/ 0.46695[0.074209, 0.380373, ]-0.512764/ 0.447718[0.145316, 0.321101, ]-0.530995/ 0.451107[0.538566, 0.275772, ]-0.471973/ 
0.422544[0.189537, 0.468399, 0.366052, 0.338387, 0.332255, 0.169107, 0.283395, 0.0996955, ]-1.40035/ 

1[]0/ 1[]0/ 
0.621294[0.234497, 0.571007, ]-0.310459/ 0.577181[0.174562, 0.51007, ]-0.373419/ 0.637951[0.509558, 0.466961, ]-0.410038/ 0.554052[0.327159, 0.319881, ]-0.429983/ 0.489754[0.377705, 0.0494304, ]-0.468127/ 0.485459[0.074209, 0.380373, ]-0.512764/ 0.483861[0.145316, 0.321101, ]-0.530995/ 0.584765[0.538566, 0.275772, ]-0.471973/ 
0.462323[0.189537, 0.468399, 0.366052, 0.338387, 0.332255, 0.169107, 0.283395, 0.0996955, ]-1.40035/

まとめ

うーんこのって感じですね。

やはり平均を取るだけじゃ駄目なような気がプンプンしてますね。次はbackprop内に取り込めないか考えてみます。ここら辺を織り合わせて学習してみるということ。

ソースはここで公開してますので自由に実行して下しあ。

ノシ